require(Matrix)
require(stringi)
require(quanteda)
require(newsmap)

#' Randomly sample up to `n` elements from every element of a list
#'
#' @param x a list of vectors
#' @param n the maximum number of elements to keep per list element
#' @return a list of the same length and names as `x`; each element holds
#'   `min(n, length(element))` values drawn without replacement
sample_list <- function(x, n = 100) {
    lapply(x, function(elem) {
        size <- min(n, length(elem))
        # Sample positions rather than values: sample(elem, size) would treat
        # a single number such as 50 as the range 1:50.
        elem[sample.int(length(elem), size)]
    })
}

#' Append one randomly chosen noise value to one randomly chosen list element
#'
#' @param x a list of vectors
#' @param noise a vector of candidate values to draw the noise from
#' @return `x` with one value from `noise` appended to one element; the added
#'   value is stored in attribute "value" and the name of the modified element
#'   in attribute "key"
insert_noise <- function(x, noise) {
    # Draw the target position first, then the noise value.
    target <- sample(length(x), 1)
    extra <- sample(noise, 1)

    x[[target]] <- c(x[[target]], extra)

    attr(x, "value") <- extra
    attr(x, "key") <- names(x[target])
    x
}

#' Grow a list by moving unused values over from a list of candidates
#'
#' When every element of `x` is empty, each element receives (at most) one
#' value from the corresponding element of `y` and `x` takes the names of `y`.
#' Otherwise a single value that is in `y[[i]]` but not yet in `x[[i]]` is
#' appended to one element of `x`.
#'
#' @param x a list of vectors to grow
#' @param y a list of candidate vectors with the same length as `x`
#' @param sequential if `TRUE`, take the first unused value of the first
#'   non-exhausted element instead of sampling at random
#' @return `x` after growing. Attribute "value" holds the value(s) added and
#'   attribute "key" the name(s) of the affected element(s); in the
#'   all-empty case "key" holds all names of `x`. Both are `character()` (and
#'   a warning is issued) when nothing is left to add.
grow_list <- function(x, y, sequential = FALSE) {
    stopifnot(length(x) == length(y))

    # Take at most one value from `pool`. Positions are sampled instead of
    # values because sample(pool, 1) would treat a single number such as 10
    # as the range 1:10.
    pick_one <- function(pool) {
        if (sequential || length(pool) == 0)
            return(head(pool, 1))
        pool[sample.int(length(pool), 1)]
    }

    if (all(lengths(x) == 0)) {
        # Initial step: seed every element at once.
        names(x) <- names(y)
        for (i in seq_along(y)) {
            x[[i]] <- pick_one(setdiff(y[[i]], x[[i]]))
        }
        attr(x, "value") <- unlist(x, use.names = FALSE)
        attr(x, "key") <- names(x)
        return(x)
    }

    # Values still available for each element.
    remaining <- Map(setdiff, y, x)
    open <- which(lengths(remaining, use.names = FALSE) > 0)

    if (length(open) == 0) {
        warning("No more words to sample", immediate. = TRUE, call. = FALSE)
        attr(x, "value") <- character()
        attr(x, "key") <- character()
        return(x)
    }

    if (sequential) {
        i <- min(open)
    } else if (length(open) > 1) {
        i <- open[sample.int(length(open), 1)]
    } else {
        i <- open
    }

    add <- pick_one(remaining[[i]])
    x[[i]] <- c(x[[i]], add)
    attr(x, "value") <- add
    attr(x, "key") <- names(x[i])
    x
}

#' Append selected values of `y[[i]]` to `x[[i]]`
#'
#' @param x a list of vectors
#' @param y a list of candidate vectors with the same length as `x`
#' @param i index (or name) of the list element to extend
#' @param j index of the value(s) within `y[[i]]` to append
#' @return `x` with `y[[i]][j]` appended to element `i`; the appended value is
#'   stored in attribute "value" and the element name in attribute "key"
append_list <- function(x, y, i, j) {
    stopifnot(length(x) == length(y))

    picked <- y[[i]][j]
    x[[i]] <- c(x[[i]], picked)

    attr(x, "value") <- picked
    attr(x, "key") <- names(x[i])
    x
}

#' Drop unwanted values from every element of a list
#'
#' @param x a list of vectors
#' @param remove a vector of values to remove from each element
#' @return a list like `x` where each element is `setdiff(element, remove)`
clean_list <- function(x, remove) {
    lapply(x, setdiff, y = remove)
}

#' Mean square root of the scaled diagonal of a feature co-occurrence matrix
#'
#' @param x an object accepted by `fcm()` and `ndoc()`
#'   (presumably a quanteda dfm or tokens object; verify against callers)
#' @param count passed to `fcm()` as its `count` argument
#' @return a single number: the mean over features of
#'   `sqrt(diagonal / number of documents)`
get_delta <- function(x, count = "boolean") {
    # The full (non-triangular) matrix is requested and densified to read
    # the diagonal.
    cooc <- as.matrix(fcm(x, count = count, tri = FALSE))
    self <- diag(cooc)
    mean(sqrt(self / ndoc(x)))
}

#' Average share of documents in which each feature occurs
#'
#' @param x a dfm for labels
#' @return the mean, over features, of the proportion of documents with a
#'   non-zero count for that feature
get_coverage <- function(x) {
    # Boolean weighting turns counts into presence/absence indicators.
    binary <- dfm_weight(x, "boolean")
    share <- colSums(binary) / ndoc(binary)
    mean(share)
}

#' Size of the dfm restricted to labelled documents, relative to the full dfm
#'
#' Prints the number of cells of the full dfm and the resulting ratio before
#' returning the ratio.
#'
#' @param x a dfm for labels
#' @param y an object whose `rowSums()` is aligned with the rows of `x`;
#'   only rows with `rowSums(y) > 0` are kept
#'   (presumably a dfm of dictionary matches; verify against callers)
#' @return the number of cells of the subsetted (and trimmed) dfm divided by
#'   the number of cells of the full dfm
get_coverage2 <- function(x, y) {
    binary <- dfm_weight(x, "boolean")
    total <- prod(dim(binary))

    kept <- binary[rowSums(y) > 0, ]
    kept <- dfm_trim(kept)

    print(total)
    ratio <- prod(dim(kept)) / total
    print(ratio)
    ratio
}

#' Sum feature counts over the documents belonging to each topic
#'
#' For every feature (column) of `y`, the documents with a positive value in
#' that column are selected from `x` and their feature counts are summed.
#'
#' @param x a dfm whose columns are summed
#' @param y a dfm with the same documents as `x`; each column defines a group
#'   of documents (those with a value greater than zero)
#' @return a dfm with one row per feature of `y` and one column per feature
#'   of `x`
group_topics <- function(x, y) {
    # `dimnames` is spelled out in full; the abbreviated `dimname` only
    # worked through partial argument matching.
    result <- matrix(NA, nrow = nfeat(y), ncol = nfeat(x),
                     dimnames = list(featnames(y), featnames(x)))
    for (i in seq_len(nfeat(y))) {
        in_group <- rowSums(y[, i]) > 0
        result[i, ] <- colSums(dfm_subset(x, in_group))
    }
    as.dfm(result)
}

#' Compute average feature entropy via `newsmap::afe()`
#'
#' @param x a dfm for features
#' @param y a dfm for labels
#' @param smooth a numeric value for smoothing to include all the features
#' @param weight accepted but not used by the current implementation
#' @return the value returned by `newsmap::afe()`
get_entropy <- function(x, y, smooth = 1, weight = FALSE) {
    newsmap::afe(x, y, smooth = smooth)
}

#' Plot entropy, coverage and F1 along a path of dictionary changes
#'
#' Draws three stacked panels sharing the same x axis (one position per row
#' of `x`): the percentage change of `e2` relative to its first value, the
#' coverage `d2` in percent, and `f1`.
#'
#' @param x a data.frame with columns `e2`, `d2`, `f1`, `noise` (logical;
#'   `TRUE` rows are drawn with filled points), `topic_added` and
#'   `word_added` (used for the labels on the top axis)
#' @return `NULL`, invisibly; called for its side effect of plotting. On
#'   exit the layout is reset to a single panel with
#'   `mar = c(5.1, 4.1, 1.1, 2.1)` and `cex = 0.9`.
plot_path2 <- function(x) {

    layout(matrix(c(1, 1, 1, 2, 3), 5, 1, byrow = TRUE))
    # Reset the graphics settings when leaving the function, also when one of
    # the plotting calls fails. cex is set last, after layout().
    on.exit({
        par(mar = c(5.1, 4.1, 1.1, 2.1))
        layout(matrix(1, 1, 1, byrow = TRUE))
        par(cex = 0.9)
    }, add = TRUE)
    par(cex = 0.9)
    par(mar = c(2.1, 4.1, 13.1, 2.1))

    step <- seq_len(nrow(x))
    guide <- seq(1, nrow(x), by = 2)
    symbol <- ifelse(x$noise, 16, 1)

    # Entropy, as percentage difference from the first row
    x$e <- (x$e2 / x$e2[1]) - 1
    x$e <- x$e * 100
    plot(step, x$e, type = "p", xlab = "", pch = symbol,
         ylab = "AFE (% diff)", xaxt = "n")
    axis(1, step, step - 1)
    grid(nx = NA, ny = NULL)
    abline(v = guide, col = "lightgray", lty = "dotted")
    # Redraw the points so that they sit on top of the grid lines.
    points(step, x$e, pch = symbol)
    lines(step, x$e, type = "s")

    # Top axis: two-letter topic code and the word added at each step
    x$topic_added <- stri_trans_toupper(stri_sub(x$topic_added, 1, 2))
    label <- ifelse(is.na(x$topic_added), "", paste0(x$topic_added, ": ", x$word_added))
    label[1] <- "[knowledge-based]"
    axis(3, seq_along(label), label, las = 2)

    par(mar = c(0.6, 4.1, 0.6, 2.1))

    # Coverage
    x$d2 <- x$d2 * 100
    plot(x$d2, type = "n", xlab = "", ylab = "Coverage (%)", xaxt = "n")
    grid(nx = NA, ny = NULL)
    abline(v = guide, col = "lightgray", lty = "dotted")
    lines(x$d2, type = "s")
    points(x$d2, pch = 1)

    # F1
    plot(x$f1, type = "n", xlab = "", ylab = "F1", xaxt = "n")
    grid(nx = NA, ny = NULL)
    abline(v = guide, col = "lightgray", lty = "dotted")
    lines(x$f1, type = "s")
    points(x$f1, pch = 1)

    invisible(NULL)
}

#' Keep the first `n` values of every list element
#'
#' @param x a list of vectors
#' @param n the number of leading values to keep per element
#' @return a list like `x` truncated with `head()`; attributes "value" and
#'   "key" are both set to `NA`
get_initial <- function(x, n = 1) {
    structure(lapply(x, head, n = n), value = NA, key = NA)
}

#' Compute F1 as the harmonic mean of two columns
#'
#' @param x a list or data.frame from which `x$p` and `x$r` are read
#'   (presumably precision and recall; note that `$` on lists and data.frames
#'   allows partial name matching)
#' @return `2 * p * r / (p + r)`, element-wise; `NaN` where `p + r` is zero
compute_f1 <- function(x) {
    p <- x$p
    r <- x$r
    2 * ((p * r) / (p + r))
}

#' Accuracy of predicted topics against human coding for selected topics
#'
#' @param x an object with elements `topic_human` and `topic`, which are
#'   passed to `accuracy()` in that order
#' @return the rows of the `accuracy()` result for the topics "greeting",
#'   "un", "security", "human", "democracy" and "development", in that order,
#'   skipping topics that are not among its row names
test_accuracy <- function(x) {
    topics <- c("greeting", "un", "security", "human", "democracy", "development")
    scores <- accuracy(x$topic_human, x$topic)
    present <- intersect(topics, rownames(scores))
    scores[present, ]
}

# topic-feature matrix
#
# Builds a sparse matrix with one row per dictionary key and one column per
# feature of `x`. Cells for (key, feature) pairs matched by the dictionary
# receive `weight`; with scheme = "relative" the weight is multiplied by the
# column sum of the matched feature in `x`. Only dictionary entries that
# resolve to exactly one feature index are used. With residual = TRUE an
# additional row named "" is appended.
#
# NOTE(review): relies on the unexported quanteda:::pattern2list() and on the
# internal slot x@meta$object$concatenator; both may change between quanteda
# versions.
tfm <- function(x, dictionary, levels = 1:5,
                valuetype = c("glob", "regex", "fixed"),
                case_insensitive = TRUE,
                weight = 500, scheme = c("relative", "absolute"),
                residual = TRUE) {

    valuetype <- match.arg(valuetype)
    scheme <- match.arg(scheme)

    matched <- quanteda:::pattern2list(dictionary, featnames(x),
                                       valuetype, case_insensitive,
                                       x@meta$object$concatenator, levels)
    topics <- attr(matched, "key")

    # Keep only the entries of length one.
    single <- matched[lengths(matched) == 1]
    row_index <- match(names(single), topics)
    col_index <- unlist(single, use.names = FALSE)

    # The residual row is added after the row indices have been resolved, so
    # it never receives a weight.
    if (residual)
        topics <- c(topics, "")
    if (scheme == "relative")
        weight <- weight * colSums(x)[col_index]

    as.dfm(sparseMatrix(i = row_index,
                        j = col_index,
                        x = weight,
                        dims = c(length(topics), nfeat(x)),
                        dimnames = list(topics, featnames(x))))
}

#' Fit a seeded LDA model with Gibbs sampling
#'
#' @param x a dfm of documents, converted with `convert(x, "topicmodels")`
#' @param y a dfm of seed words; its number of rows determines the number of
#'   topics `k`
#' @return the fitted model returned by `LDA()`
textmodel_slda <- function(x, y) {
    dtm <- convert(x, "topicmodels")
    # Empty rows are kept so that the number of rows matches the number of
    # rows in `y`.
    dtm_seed <- quanteda:::dfm2dtm(y, omit_empty = FALSE)
    # NOTE(review): LDA() is presumably topicmodels::LDA and must be attached
    # by the caller. The sampler runs with its default control settings; a
    # custom control list (alpha = 0.1, best = TRUE, verbose = 500,
    # burnin = 500, iter = 100, thin = 100, prefix = character()) used to be
    # built here but was never passed on, so it has been removed.
    LDA(dtm, k = dtm_seed$nrow, method = "Gibbs", seedwords = dtm_seed)
}


# https://github.com/bstewart/stm/blob/7a8206c24d566635cc0f74f39fc1882d5e846996/R/exclusivity.R
#' Top terms per topic ranked by combined frequency and exclusivity
#'
#' For every topic, terms are scored by the weighted harmonic mean of two
#' within-topic ranks: the rank of the term's share of its total probability
#' across topics (exclusivity) and the rank of its probability in the topic
#' (frequency).
#'
#' @param x a fitted topic model for which `posterior(x)$terms` is a
#'   topics-by-terms matrix with term names as column names
#' @param w weight given to exclusivity; `1 - w` is given to frequency
#' @param n number of top terms to return per topic
#' @return the result of `apply()` over topics: with `n > 1` a character
#'   matrix with one column per topic and the top terms in rows
frex <- function(x, w = 0.5, n = 20) {
    # Exclusivity is computed on the probabilities themselves, as in the
    # referenced stm code. Normalising log-probabilities instead would invert
    # the ranking, because all logs are negative.
    beta_t <- t(posterior(x)$terms)
    mat <- beta_t / rowSums(beta_t)

    ex <- apply(mat, 2, rank) / nrow(mat)
    fr <- apply(beta_t, 2, rank) / nrow(mat)
    score <- t(1 / (w / ex + (1 - w) / fr))
    apply(score, 1, function(s) head(names(sort(s, decreasing = TRUE)), n))
}

#' Collapse the top terms of each column into one comma-separated string
#'
#' @param x a matrix or data.frame with one column of terms per topic
#' @param n the number of leading terms to keep per column
#' @return a one-column character matrix with one row per column of `x`,
#'   each holding the first `n` terms joined by ", "
compact_terms <- function(x, n = 10) {
    columns <- as.list(as.data.frame(x))
    joined <- lapply(columns, function(terms) {
        paste(head(terms, n), collapse = ", ")
    })
    t(as.data.frame(joined))
}
